import torch
import json



def call_gpt(input, model, model_tokenizer, device, stop=("\n\n",), max_tokens=70):
    """Greedily decode a continuation of ``input`` and cut it at the earliest stop sequence.

    Args:
        input: Prompt text.
        model: Object exposing ``generate(**encoded, ...)`` (presumably a Hugging Face
            causal LM); decoding is greedy (``do_sample=False``, ``num_beams=1``).
        model_tokenizer: Tokenizer matching ``model``.
        device: Device the encoded prompt is moved to.
        stop: Stop sequences. The continuation is truncated right before the earliest
            occurrence of any of them. Empty strings are ignored.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        ``(text, flag)``. When at least one stop sequence occurs in the continuation,
        ``text`` is everything before the earliest one and ``flag`` is ``False``.
        When none occurs, ``text`` is ``""`` and ``flag`` is ``True`` (callers treat
        this as a failed step).
    """
    # Stop on the regular EOS token and on "<|eot_id|>" (presumably the end-of-turn
    # token of Llama-3-style chat models).
    terminators = [
        model_tokenizer.eos_token_id,
        model_tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    # Tokenizers lacking one of these tokens may report None; generate() needs ints.
    terminators = [tok_id for tok_id in terminators if tok_id is not None]

    encoded = model_tokenizer(input, return_tensors="pt").to(device)
    outputs = model.generate(
        **encoded,
        max_new_tokens=max_tokens,
        # None lets generate() fall back to the model's own generation config.
        eos_token_id=terminators or None,
        do_sample=False,
        num_beams=1,
    )

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    prompt_len = encoded["input_ids"].shape[-1]
    response = model_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # The shortest "text before a stop sequence" is the cut at the earliest occurrence.
    cut_points = [response.find(tok) for tok in stop if tok and tok in response]
    if not cut_points:
        return "", True
    return response[:min(cut_points)], False


def _cls_embeddings(texts, tokenizer, classifier, device):
    """Encode ``texts`` with ``classifier`` and return each input's first-token hidden state."""
    with torch.no_grad():
        encoded = tokenizer(texts, padding=True, truncation=True, max_length=256,
                            return_tensors='pt').to(device)
        return classifier(**encoded).last_hidden_state[:, 0]


def _collect_edit_facts(mquake_dataset, rand_list):
    """Return the sorted, deduplicated rewritten facts of all cases whose id is in ``rand_list``."""
    facts = set()
    for case in mquake_dataset.get_dataset():
        if case['case_id'] not in rand_list:
            continue
        for r in case["requested_rewrite"]:
            facts.add(f'{r["prompt"].format(r["subject"])} {r["target_new"]["str"]}')
    # Sorted (not raw set order) so results do not depend on the string hash seed.
    return sorted(facts) or ["No relevant fact."]


def _retrieve_edit(subquestion, edits_batch, embs_batch, tokenizer, classifier, seq_cls, device):
    """Return the edited fact retrieved for ``subquestion``, or None when none passes the thresholds."""
    subquestion_emb = _cls_embeddings(subquestion, tokenizer, classifier, device)
    # Score every fact by exp(-||fact_emb - subquestion_emb||^2); 0.5 is the match threshold.
    prob = (-((embs_batch - subquestion_emb).norm(2, -1) ** 2)).exp()
    if prob.max() < 0.5:
        return None

    candidates = [fact for fact, keep in zip(edits_batch, (prob >= 0.5).tolist()) if keep]
    if len(candidates) <= 1:
        _, index = prob.max(0)
        return edits_batch[int(index)]

    # Several facts pass the embedding threshold: rerank "fact [SEP] subquestion" pairs
    # with seq_cls, treating column 0 of its softmax as the match probability.
    pair_texts = [fact + tokenizer.sep_token + subquestion for fact in candidates]
    with torch.no_grad():
        encoded = tokenizer(pair_texts, padding=True, truncation=True, max_length=256,
                            return_tensors="pt").to(device)
        seq_prob = torch.softmax(seq_cls(**encoded).logits, -1)[:, 0]
    if seq_prob.max() < 0.5:
        return None
    _, index = seq_prob.max(0)
    return candidates[int(index)]


def pokemqa_eval_loop(mquake_dataset, dataset_name, masking, tokenizer, classifier, device,
                      task_wokg_prompt, seq_cls, result_file_path, model, model_tokenizer):
    """Run edit-aware multi-hop question answering over the dataset and dump the answers.

    For every question, the LLM is prompted for up to 5 steps. Each step should end in
    either ``Final answer: ...`` (recorded) or ``Subquestion: ...``. For a subquestion,
    an edited fact is retrieved by embedding similarity (and, if several facts match,
    reranked by ``seq_cls``). The fact is appended as ``Generated answer: <fact>.``;
    otherwise the model is left to write the intermediate answer itself.

    Args:
        mquake_dataset: Object providing ``get_randlist()``, ``get_dataset()`` and
            ``get_edits_without_contamination(rand_list, case)``.
        dataset_name: Dataset identifier; ``'CF-6334'`` always uses the shared fact set.
        masking: If truthy (and the dataset is not CF-6334), each case uses its own fact
            set from ``get_edits_without_contamination``; otherwise all rewrites of the
            cases in ``rand_list`` form one shared fact set.
        tokenizer: Tokenizer for ``classifier`` / ``seq_cls`` (must define ``sep_token``).
        classifier: Encoder whose first-token hidden state is used as the text embedding.
        device: Device for encoder / LLM inputs.
        task_wokg_prompt: Few-shot prompt prefix for the LLM.
        seq_cls: Pair classifier used to rerank multiple candidate facts.
        result_file_path: JSON file rewritten after every case with all answers so far.
        model: LLM passed to ``call_gpt``.
        model_tokenizer: Tokenizer for ``model``.

    Writes ``{case_id: {'edited': bool, 'answers': [str | None, ...]}}`` to
    ``result_file_path``. ``None`` marks a question without a final answer.
    """
    rand_list = mquake_dataset.get_randlist()
    infer_dataset = mquake_dataset.get_dataset()
    # Each LLM call stops right before it would write a "Generated answer:" itself.
    stop = ["Generated answer:", "\n\n"]
    per_case_facts = dataset_name != 'CF-6334' and masking

    if not per_case_facts:
        # The fact set is identical for every case, so embed it once up front.
        edits_batch = _collect_edit_facts(mquake_dataset, rand_list)
        embs_batch = _cls_embeddings(edits_batch, tokenizer, classifier, device)

    answer_dict = {}
    for d in infer_dataset:
        if per_case_facts:
            new_facts, _, _, _ = mquake_dataset.get_edits_without_contamination(rand_list, d)
            if not new_facts:
                new_facts = ["No relevant fact."]
            edits_batch = sorted(set(new_facts))
            embs_batch = _cls_embeddings(edits_batch, tokenizer, classifier, device)

        case_id = d['case_id']
        answer_dict[case_id] = {'edited': case_id in rand_list, 'answers': []}

        for q in d["questions"]:
            ans = None
            prompt = task_wokg_prompt + "\n\nQuestion: " + q

            for step in range(5):
                # Prompt the model for the next subquestion (tighter budget on the first hop).
                if step == 0:
                    gen, flag = call_gpt(prompt, model, model_tokenizer, device, stop, 35)
                else:
                    gen, flag = call_gpt(prompt, model, model_tokenizer, device, stop)
                if flag:
                    break  # no stop sequence was produced: failed case

                lines = gen.strip().split('\n')
                last_sent = lines[-1]
                if last_sent.startswith('Final answer: '):
                    ans = last_sent[len("Final answer: "):]
                    break
                if len(lines) > 3 or not last_sent.startswith('Subquestion: '):
                    break  # failed case
                subquestion = last_sent[len("Subquestion: "):]

                # Conflict detection: does an edited fact answer this subquestion?
                fact = _retrieve_edit(subquestion, edits_batch, embs_batch,
                                      tokenizer, classifier, seq_cls, device)
                if fact is None:
                    # NOTE(review): no colon here, presumably so the model's next
                    # continuation starts with ": <its own answer>" — confirm intent.
                    prompt = prompt + gen + "Generated answer"
                else:
                    prompt = prompt + gen + 'Generated answer: ' + fact + '.'

            answer_dict[case_id]['answers'].append(ans)
            print(answer_dict[case_id]['answers'])

        print('==' * 50)

        # Rewritten after every case so partial results survive an interrupted run.
        with open(result_file_path, 'w') as fp:
            json.dump(answer_dict, fp, indent = 4)

    print('inference stop!')